<!--
  Tab list for the code examples below.
  Each entry is an <li id="<name>Item"> wrapping an <a id="<name>Button">.
  The inline script at the bottom of this file looks these ids up with
  getElementById, marks the selected <li> with the class "active" and shows
  the panel whose id is "<name>Code". Renaming an id here requires the same
  change in that script.
  The links use href="" because they are only click targets; the script
  cancels the default navigation.
  NOTE(review): <center> is obsolete in HTML5; consider centering through the
  stylesheet that defines the "downloads" class (not visible in this file).
-->
<center>
    <ul class="downloads">
        <li id="gradientItem" class="active">
            <a href="" id="gradientButton">
                the gradient with
                <strong>
                    PyTorch
                </strong>
            </a>
        </li>
        <li id="varianceItem">
            <a href="" id="varianceButton">
                an estimate of the
                <strong>
                    Variance
                </strong>
            </a>
        </li>
        <li id="diagGGNItem">
            <a href="" id="diagGGNButton">
                the Gauss-Newton
                <strong>
                    Diagonal
                </strong>
            </a>
        </li>
        <li id="KFACItem">
            <a href="" id="KFACButton">
                the Gauss-Newton
                <strong>
                    KFAC
                </strong>
            </a>
        </li>
    </ul>
</center>

<!-- Gradient example: plain PyTorch, no BackPACK. This is the only panel without display:none, so it is the one shown on page load. The markup inside <pre> is whitespace-sensitive. -->
<div id="gradientCode">
<div class="language-python highlighter-rouge">
    <div class="highlight">
        <pre class="highlight"><code><span class="s">"""
Compute the gradient with PyTorch

"""</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">CrossEntropyLoss</span><span class="p">,</span> <span class="n">Linear</span>
<span class="kn">from</span> <span class="nn">backpack.utils.examples</span> <span class="kn">import</span> <span class="n">load_one_batch_mnist</span>


<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_one_batch_mnist</span><span class="p">(flat=True)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
<span class="n">lossfunc</span> <span class="o">=</span> <span class="n">CrossEntropyLoss</span><span class="p">()</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">lossfunc</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">)</span>


<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>

<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
    <span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="p">)</span>

</code></pre>
    </div>
</div>
</div>

<!-- Variance example (hidden until its tab is selected). Lines wrapped in a blue span are the additions relative to the plain PyTorch example. The sample imports Variance by name, so the call must be Variance(), not extensions.Variance(). The markup inside <pre> is whitespace-sensitive. -->
<div id="varianceCode" style="display:none;">
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Compute the gradient with PyTorch
and the variance with BackPACK
"""</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">CrossEntropyLoss</span><span class="p">,</span> <span class="n">Linear</span>
<span class="kn">from</span> <span class="nn">backpack.utils.examples</span> <span class="kn">import</span> <span class="n">load_one_batch_mnist</span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">extend</span><span class="p">,</span> <span class="n">backpack</span></span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack.extensions</span> <span class="kn">import</span> <span class="n">Variance</span></span>

<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_one_batch_mnist</span><span class="p">(flat=True)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">lossfunc</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">CrossEntropyLoss</span><span class="p">())</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">lossfunc</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">)</span>

<span style="color: blue;"><span class="k">with</span> <span class="n">backpack</span><span class="p">(</span><span class="n">Variance</span><span class="p">()):</span></span>
    <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>

<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
    <span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="p">)</span>
    <span style="color: blue;"><span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">variance</span><span class="p">)</span></span>
</code></pre></div></div></div>


<!-- Second moment example. Lines wrapped in a blue span are the additions relative to the plain PyTorch example. NOTE(review): no tab in the list above and no handler in the script below refers to the id "secondMomentCode" in the visible file, so this panel is never shown; either wire it up or remove it. The markup inside <pre> is whitespace-sensitive. -->
<div id="secondMomentCode" style="display:none;">
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Compute the gradient with PyTorch
and the second moment with BackPACK
"""</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">CrossEntropyLoss</span><span class="p">,</span> <span class="n">Linear</span>
<span class="kn">from</span> <span class="nn">backpack.utils.examples</span> <span class="kn">import</span> <span class="n">load_one_batch_mnist</span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">extend</span><span class="p">,</span> <span class="n">backpack</span></span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack.extensions</span> <span class="kn">import</span> <span class="n">SumGradSquared</span></span>

<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_one_batch_mnist</span><span class="p">(flat=True)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">lossfunc</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">CrossEntropyLoss</span><span class="p">())</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">lossfunc</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">)</span>

<span style="color: blue;"><span class="k">with</span> <span class="n">backpack</span><span class="p">(</span><span class="n">SumGradSquared</span><span class="p">()):</span></span>
    <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>

<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
    <span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="p">)</span>
    <span style="color: blue;"><span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">sum_grad_squared</span><span class="p">)</span></span>
</code></pre></div></div></div>

<!-- Gauss-Newton diagonal example (hidden until its tab is selected). Lines wrapped in a blue span are the additions relative to the plain PyTorch example. The sample imports DiagGGNExact by name, so the call must be DiagGGNExact(), not extensions.DiagGGNExact(). The markup inside <pre> is whitespace-sensitive. -->
<div id="diagGGNCode" style="display:none;">
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Compute the gradient with PyTorch
and the diagonal of the Gauss-Newton with BackPACK
"""</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">CrossEntropyLoss</span><span class="p">,</span> <span class="n">Linear</span>
<span class="kn">from</span> <span class="nn">backpack.utils.examples</span> <span class="kn">import</span> <span class="n">load_one_batch_mnist</span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">extend</span><span class="p">,</span> <span class="n">backpack</span></span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack.extensions</span> <span class="kn">import</span> <span class="n">DiagGGNExact</span></span>

<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_one_batch_mnist</span><span class="p">(flat=True)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">lossfunc</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">CrossEntropyLoss</span><span class="p">())</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">lossfunc</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">)</span>

<span style="color: blue;"><span class="k">with</span> <span class="n">backpack</span><span class="p">(</span><span class="n">DiagGGNExact</span><span class="p">()):</span></span>
    <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>

<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
    <span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="p">)</span>
    <span style="color: blue;"><span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">diag_ggn_exact</span><span class="p">)</span></span>
</code></pre></div></div></div>

<!-- KFAC example (hidden until its tab is selected). Lines wrapped in a blue span are the additions relative to the plain PyTorch example. The sample imports KFAC by name, so the call must be KFAC(), not extensions.KFAC(). The markup inside <pre> is whitespace-sensitive. -->
<div id="KFACCode" style="display:none;">
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="s">"""
Compute the gradient with PyTorch
and KFAC with BackPACK
"""</span>
<span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">CrossEntropyLoss</span><span class="p">,</span> <span class="n">Linear</span>
<span class="kn">from</span> <span class="nn">backpack.utils.examples</span> <span class="kn">import</span> <span class="n">load_one_batch_mnist</span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack</span> <span class="kn">import</span> <span class="n">extend</span><span class="p">,</span> <span class="n">backpack</span></span>
<span style="color: blue;"><span class="kn">from</span> <span class="nn">backpack.extensions</span> <span class="kn">import</span> <span class="n">KFAC</span></span>

<span class="n">X</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">load_one_batch_mnist</span><span class="p">(flat=True)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">Linear</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">lossfunc</span> <span class="o">=</span> <span class="n"><span style="color: blue;">extend</span></span><span class="p">(</span><span class="n">CrossEntropyLoss</span><span class="p">())</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">lossfunc</span><span class="p">(</span><span class="n">model</span><span class="p">(</span><span class="n">X</span><span class="p">),</span> <span class="n">y</span><span class="p">)</span>

<span style="color: blue;"><span class="k">with</span> <span class="n">backpack</span><span class="p">(</span><span class="n">KFAC</span><span class="p">()):</span></span>
    <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>

<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
    <span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">grad</span><span class="p">)</span>
    <span style="color: blue;"><span class="k">print</span><span class="p">(</span><span class="n">param</span><span class="o">.</span><span class="n">kfac</span><span class="p">)</span></span>
</code></pre></div></div>
</div>

<script>
    // Tab switcher for the code examples.
    //
    // Registered with addEventListener so that it coexists with any other
    // "load" handler on the page instead of replacing window.onload.
    window.addEventListener("load", function() {
        // Names of the tabs. For every name the page is expected to provide
        // three elements:
        //   <name>Button  the link that selects the tab,
        //   <name>Item    the <li> that carries the "active" class,
        //   <name>Code    the panel holding the example code.
        var tabNames = ["gradient", "variance", "diagGGN", "KFAC"];
        var tabs = {};

        tabNames.forEach(function(name) {
            var button = document.getElementById(name + "Button");
            var item = document.getElementById(name + "Item");
            var code = document.getElementById(name + "Code");
            // Skip a tab whose markup is incomplete so that one missing
            // element does not throw and disable every other tab.
            if (button && item && code) {
                tabs[name] = { button: button, item: item, code: code };
            }
        });

        // Mark the tab called `name` as active and show its panel; every
        // other known tab is deactivated and its panel hidden.
        // For an unknown name no tab is marked active and the gradient
        // panel is shown, so the page never ends up without an example.
        var setNewCurrent = function(name) {
            var known = Object.prototype.hasOwnProperty.call(tabs, name);

            Object.keys(tabs).forEach(function(key) {
                var tab = tabs[key];
                var selected = known && key === name;
                // classList only touches "active" and leaves any other
                // class on the <li> in place.
                if (selected) {
                    tab.item.classList.add("active");
                } else {
                    tab.item.classList.remove("active");
                }
                tab.code.style.display = selected ? "block" : "none";
            });

            if (!known && tabs.gradient) {
                tabs.gradient.code.style.display = "block";
            }
        };

        Object.keys(tabs).forEach(function(name) {
            tabs[name].button.addEventListener("click", function(event) {
                // The links use href="", so following them would reload
                // the page; cancel that before switching tabs.
                event.preventDefault();
                setNewCurrent(name);
            });
        });
    });
</script>
